package com.wujunshen.chess.engine;

import static com.wujunshen.chess.common.Constants.*;

import com.github.bhlangonijr.chesslib.Board;
import com.github.bhlangonijr.chesslib.PieceType;
import com.github.bhlangonijr.chesslib.move.Move;
import com.wujunshen.chess.StartPosition;
import com.wujunshen.chess.common.ChessProperties;
import java.util.ArrayList;
import java.util.List;
import lombok.NoArgsConstructor;
import org.apache.commons.lang3.tuple.Triple;

/**
 * Chooses a move from a search tree rooted at a {@link Node}.
 *
 * <p>This class never creates child nodes. It only reads child values, prunes the tree in place
 * ({@link #cutChildren}) and reports a move. The tree is presumably built and scored elsewhere
 * before {@link #search} is called.
 *
 * @author wujunshen
 */
@NoArgsConstructor
public class SearchEngine {

  /**
   * Picks the move to play from the children of {@code rootNode}.
   *
   * <p>The method iterates depths 1 through {@code chessProperties.getDepth()}:
   *
   * <ul>
   *   <li>Once the depth reaches {@code chessProperties.getLeafDepth()}, the tree is pruned.
   *       The pruning width halves for every extra depth level and never drops below
   *       {@code CUT_WIDTH}.
   *   <li>A candidate is then recorded for that depth. It is the first winning child found (a
   *       child whose result equals {@code WIN}), or otherwise the highest-valued child.
   * </ul>
   *
   * <p>The candidate with the highest value is returned. On ties, the one recorded at the
   * greater depth wins.
   *
   * @param rootNode root of the search tree; its children are the candidate moves
   * @param chessProperties supplies the depth, leaf-depth and width settings
   * @return (moving piece type, move, node value) of the chosen move, or {@code null} if no
   *     candidate was found (for example, the root has no children)
   */
  public Triple<PieceType, Move, Float> search(Node rootNode, ChessProperties chessProperties) {
    List<Triple<PieceType, Move, Float>> candidates = new ArrayList<>();
    Triple<PieceType, Move, Float> checkmateMove = null;

    for (int depth = 1; depth <= chessProperties.getDepth(); depth++) {
      System.out.println("batch_run depth=" + depth);

      int cutDepth = depth - chessProperties.getLeafDepth();
      if (cutDepth >= 0) {
        int cutWidth = (int) Math.ceil(chessProperties.getWidth() / Math.pow(2, cutDepth));
        cutWidth = Math.max(cutWidth, CUT_WIDTH);

        cutChildren(rootNode, chessProperties.getWidth(), cutDepth, cutWidth);

        System.out.println("cut_children  cut_depth=" + cutDepth + " cut_width=" + cutWidth);
      }

      // Once a winning move has been seen, it is reused for every later depth.
      if (checkmateMove == null) {
        checkmateMove = checkmate(rootNode.getChildrenNodes());
      }
      // Always evaluated, because bestMoves also prints the top candidates for this depth.
      Triple<PieceType, Move, Float> bestMove = bestMoves(rootNode.getChildrenNodes());

      candidates.add(checkmateMove != null ? checkmateMove : bestMove);
    }

    return highestValued(candidates);
  }

  /**
   * Returns the candidate with the greatest value, ignoring {@code null} entries.
   *
   * <p>On ties, the later candidate (recorded at a greater depth) is preferred.
   *
   * @param candidates per-depth candidates; may contain {@code null}
   * @return the best candidate, or {@code null} if there is no non-null candidate
   */
  private static Triple<PieceType, Move, Float> highestValued(
      List<Triple<PieceType, Move, Float>> candidates) {
    Triple<PieceType, Move, Float> best = null;
    for (Triple<PieceType, Move, Float> candidate : candidates) {
      if (candidate == null) {
        continue;
      }
      // ">=" lets a later candidate with an equal value replace an earlier one.
      if (best == null || Float.compare(candidate.getRight(), best.getRight()) >= 0) {
        best = candidate;
      }
    }
    return best;
  }

  /**
   * Finds the highest-valued child whose result equals {@code WIN}.
   *
   * @param childrenNodes candidate child nodes
   * @return (moving piece type, move, value) of that child, or {@code null} if no child wins
   */
  private Triple<PieceType, Move, Float> checkmate(List<Node> childrenNodes) {
    List<Node> winningNodes =
        sortedByValueDescending(
            childrenNodes.stream().filter(i -> WIN.equals(i.getResult())).toList());

    return winningNodes.isEmpty() ? null : describe(winningNodes.getFirst());
  }

  /**
   * Finds the highest-valued child.
   *
   * <p>As a debug trace, it also prints up to {@code CUT_WIDTH} of the best children to stdout.
   *
   * @param childrenNodes candidate child nodes
   * @return (moving piece type, move, value) of the best child, or {@code null} if there are no
   *     children
   */
  private Triple<PieceType, Move, Float> bestMoves(List<Node> childrenNodes) {
    List<Node> ranked = sortedByValueDescending(childrenNodes);
    if (ranked.isEmpty()) {
      return null;
    }

    int shown = Math.min(ranked.size(), CUT_WIDTH);
    for (int i = 0; i < shown; i++) {
      Triple<PieceType, Move, Float> described = describe(ranked.get(i));
      System.out.println(
          described.getLeft() + " " + described.getMiddle() + " " + described.getRight());
    }

    return describe(ranked.getFirst());
  }

  /**
   * Recovers the last move played on {@code node}'s board and the type of the piece that made
   * it.
   *
   * <p>The board is rewound with {@code undoMove()} so the piece can be read from the move's
   * origin square. The move is then replayed with {@code doMove()}.
   *
   * @param node a child node whose board has at least one move in its history
   * @return (moving piece type, move, node value)
   */
  private static Triple<PieceType, Move, Float> describe(Node node) {
    // Read the value before touching the board, matching the original evaluation order.
    Float value = node.getValue();
    Board board = node.getBoard();
    Move lastMove = board.undoMove();

    PieceType pieceType = board.getPiece(lastMove.getFrom()).getPieceType();
    board.doMove(lastMove);

    return Triple.of(pieceType, lastMove, value);
  }

  /**
   * Returns {@code nodes} ordered by value from lowest to highest.
   *
   * <p>The sort is stable, so equal values keep their original relative order. The result is
   * unmodifiable.
   */
  private static List<Node> sortedByValueAscending(List<Node> nodes) {
    return nodes.stream().sorted((a, b) -> Float.compare(a.getValue(), b.getValue())).toList();
  }

  /**
   * Returns {@code nodes} ordered by value from highest to lowest.
   *
   * <p>This is the ascending order reversed, so equal values appear in the reverse of their
   * original order. The result is unmodifiable.
   */
  private static List<Node> sortedByValueDescending(List<Node> nodes) {
    return sortedByValueAscending(nodes).reversed();
  }

  /**
   * Prunes the subtree under {@code node} in place.
   *
   * <p>At this level, at most {@code cutWidth} children are kept. Each level below keeps twice
   * as many, capped at {@code width}. Pruning covers {@code cutDepth + 1} levels in total.
   *
   * <p>Children are ranked by value. When the side to move at {@code node} is
   * {@code StartPosition.initSide}, the highest-valued children are kept; otherwise the
   * lowest-valued ones are kept.
   *
   * @param node subtree root whose children are pruned
   * @param width upper bound for the per-level width
   * @param cutDepth remaining levels to prune; nothing happens when negative
   * @param cutWidth number of children to keep at this level
   */
  private void cutChildren(Node node, int width, int cutDepth, int cutWidth) {
    List<Node> childrenNodes = node.getChildrenNodes();
    if (cutDepth < 0 || childrenNodes.isEmpty()) {
      return;
    }

    List<Node> ranked = sortedByValueAscending(childrenNodes);
    // reversed() returns a new view rather than reversing in place, so it must be reassigned.
    if (node.getBoard().getSideToMove() == StartPosition.initSide) {
      ranked = ranked.reversed();
    }

    // NOTE(review): this is an unmodifiable view (toList + subList). Confirm nothing later
    // mutates a node's children list.
    List<Node> kept = ranked.subList(0, Math.min(cutWidth, ranked.size()));
    node.setChildrenNodes(kept);

    int childCutWidth = Math.min(cutWidth * 2, width);
    for (Node child : kept) {
      cutChildren(child, width, cutDepth - 1, childCutWidth);
    }
  }
}
